
import torch.functional as F
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np
# from .pytorch_msssim import pytorch_msssim





class CharbonnierLoss(nn.Module):
    """Summed Charbonnier penalty, a smooth and differentiable variant of L1.

    ``loss = sum(sqrt((x - y)**2 + eps))`` over every element; ``eps`` keeps
    the square root differentiable where ``x == y``.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        residual = x - y
        penalty = (residual * residual + self.eps).sqrt()
        return penalty.sum()



class Ternary(nn.Module):
    """Soft ternary (census-style) distance between two RGB images.

    Both images are converted to grayscale. Each pixel is then compared with
    every pixel of its 7x7 neighbourhood, and the differences are squashed
    into (-1, 1) with ``d / sqrt(0.81 + d**2)``. The two resulting 49-channel
    codes are compared with a soft Hamming distance.

    The output has shape (n, 1, H, W). Neighbours outside the image are
    zero-padded, and only the outermost one-pixel border is masked to zero.
    """

    def __init__(self):
        super(Ternary, self).__init__()
        patch_size = 7
        out_channels = patch_size * patch_size
        self.patch_size = patch_size
        # One-hot 7x7 kernels: output channel k copies the k-th neighbour of
        # each pixel, so a single conv2d gathers the whole neighbourhood.
        w = np.eye(out_channels).reshape(
            (patch_size, patch_size, 1, out_channels))
        w = np.transpose(w, (3, 2, 0, 1))
        self.w = torch.tensor(w).float()
        # Keep the original GPU placement when CUDA exists, but do not crash
        # on CPU-only machines; transform() adapts the kernel to its input.
        if torch.cuda.is_available():
            self.w = self.w.cuda()

    def transform(self, img):
        """Return soft neighbourhood-difference codes for a 1-channel image.

        Args:
            img: tensor indexed (n, 1, H, W), as produced by ``rgb2gray``.

        Returns:
            Tensor (n, 49, H, W) with values in (-1, 1).
        """
        # Match the input's device and dtype, since ``self.w`` is a plain
        # attribute and does not follow ``module.to(...)``.
        weight = self.w.to(device=img.device, dtype=img.dtype)
        patches = torch.nn.functional.conv2d(
            img, weight, padding=self.patch_size // 2, bias=None)
        transf = patches - img
        transf_norm = transf / torch.sqrt(0.81 + transf ** 2)
        return transf_norm

    def rgb2gray(self, rgb):
        """Luma-weighted grayscale of channels 0, 1, 2 (assumed R, G, B)."""
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray

    def hamming(self, t1, t2):
        """Soft Hamming distance: channel mean of d / (0.1 + d), d = (t1 - t2)**2."""
        dist = (t1 - t2) ** 2
        dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
        return dist_norm

    def valid_mask(self, t, padding):
        """Return an (n, 1, h, w) mask of ones with a ``padding``-wide zero border."""
        n, _, h, w = t.size()
        inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
        # ``torch.functional`` has no ``pad``; the op lives in torch.nn.functional.
        mask = torch.nn.functional.pad(inner, [padding] * 4)
        return mask

    def forward(self, img0, img1):
        """Return the per-pixel ternary distance map (n, 1, H, W)."""
        img0 = self.transform(self.rgb2gray(img0))
        img1 = self.transform(self.rgb2gray(img1))
        return self.hamming(img0, img1) * self.valid_mask(img0, 1)



class MeanShift(nn.Conv2d):
    """Fixed 1x1 convolution computing ``(x + sign * mean) / std`` per channel.

    With the default ``sign=-1`` this subtracts ``rgb_mean`` and divides by
    ``rgb_std``. The weight and bias are frozen and never receive gradients.

    Args:
        rgb_mean: three per-channel means.
        rgb_std: three per-channel standard deviations.
        sign: -1 to subtract the mean, +1 to add it.
    """

    def __init__(self, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Assigning ``self.requires_grad`` only creates a plain module
        # attribute; freezing requires clearing the flag on each parameter.
        for param in self.parameters():
            param.requires_grad = False
        self.requires_grad = False  # kept for backward compatibility


class HuberLoss(nn.Module):
    """Mean Huber loss with transition point ``delta``.

    Per element, with ``r = |sr - hr|``: ``0.5 * r**2`` when ``r < delta``,
    otherwise ``delta * (r - 0.5 * delta)``. The result is averaged over all
    elements.
    """

    def __init__(self, delta=1):
        super().__init__()
        self.delta = delta

    def forward(self, sr, hr):
        abs_err = (sr - hr).abs()
        in_quadratic_zone = abs_err < self.delta
        quadratic = 0.5 * (abs_err ** 2)
        linear = self.delta * (abs_err - 0.5 * self.delta)
        # Blend by multiplying with the boolean mask and its complement.
        blended = in_quadratic_zone * quadratic + in_quadratic_zone.logical_not() * linear
        return blended.mean()


class VGG(nn.Module):
    """Perceptual loss: MSE between pretrained VGG-19 features of ``sr`` and ``hr``.

    ``loss_type[-2:]`` selects the feature depth:

    - ``'22'``, ``'33'``, ``'44'`` or ``'54'`` keep the first 8, 16, 26 or 35
      modules of torchvision's ``vgg19().features``.
    - ``'P'`` (only reachable when ``loss_type`` is exactly ``'P'``) splits
      those 35 modules into four consecutive stages and sums the MSE of every
      stage's output.

    Inputs are normalized with ImageNet-style mean/std via ``MeanShift``.
    They presumably lie in [0, 1]; confirm against the caller.

    Raises:
        ValueError: if ``loss_type`` selects none of the variants above.
    """

    def __init__(self, loss_type):
        super(VGG, self).__init__()
        conv_index = loss_type[-2:]
        # Number of leading ``vgg19().features`` modules per single-level variant.
        single_level_end = {'22': 8, '33': 16, '44': 26, '54': 35}
        # Validate before constructing/loading the pretrained network;
        # otherwise ``self.vgg`` would silently stay unset.
        if conv_index != 'P' and conv_index not in single_level_end:
            raise ValueError('Unsupported VGG loss type: {!r}'.format(loss_type))

        modules = list(models.vgg19(pretrained=True).features)
        if conv_index == 'P':
            bounds = (0, 8, 16, 26, 35)
            self.vgg = nn.ModuleList([
                nn.Sequential(*modules[start:end])
                for start, end in zip(bounds, bounds[1:])
            ])
        else:
            self.vgg = nn.Sequential(*modules[:single_level_end[conv_index]])
        self.vgg = nn.DataParallel(self.vgg).cuda()

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229, 0.224, 0.225)
        self.sub_mean = MeanShift(vgg_mean, vgg_std)
        # Assigning ``requires_grad`` on a module only sets an attribute; the
        # pretrained weights must be frozen parameter by parameter.
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.vgg.requires_grad = False  # kept for backward compatibility
        self.conv_index = conv_index

    def forward(self, sr, hr):
        """Return the feature-space MSE between ``sr`` and ``hr``.

        Gradients flow only through ``sr``; the ``hr`` features are computed
        under ``torch.no_grad()``.
        """
        mse = torch.nn.functional.mse_loss

        def _forward(x):
            return self.vgg(self.sub_mean(x))

        def _forward_all(x):
            # Run the stages of the 'P' variant in sequence, keeping each output.
            feats = []
            x = self.sub_mean(x)
            for module in self.vgg.module:
                x = module(x)
                feats.append(x)
            return feats

        if self.conv_index == 'P':
            vgg_sr_feats = _forward_all(sr)
            with torch.no_grad():
                vgg_hr_feats = _forward_all(hr.detach())
            loss = 0
            for feat_sr, feat_hr in zip(vgg_sr_feats, vgg_hr_feats):
                loss += mse(feat_sr, feat_hr)
        else:
            vgg_sr = _forward(sr)
            with torch.no_grad():
                vgg_hr = _forward(hr.detach())
            loss = mse(vgg_sr, vgg_hr)

        return loss


# For Adversarial loss
class BasicBlock(nn.Sequential):
    """``Conv2d`` followed by an optional ``BatchNorm2d`` and optional activation.

    The convolution is padded by ``kernel_size // 2``, so with ``stride=1``
    and an odd kernel the spatial size is preserved. The convolution has no
    bias by default.

    Note: the default ``act`` is a single ``nn.ReLU`` instance created once
    at definition time and shared by every block that relies on the default.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True)):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size // 2),
            stride=stride,
            bias=bias,
        )
        layers = [conv]
        layers.extend([nn.BatchNorm2d(out_channels)] if bn else [])
        layers.extend([] if act is None else [act])
        super(BasicBlock, self).__init__(*layers)


class Discriminator(nn.Module):
    """Convolutional discriminator mapping an image crop to one real/fake logit.

    The feature extractor is one stride-1 ``BasicBlock`` followed by
    ``depth`` blocks that alternate stride 2 (downsample) and stride 1 (double
    the channels). A two-layer MLP turns the flattened features into a single
    logit per sample.

    Args:
        args: object with an integer ``patch_size`` attribute; inputs are
            center-cropped to ``patch_size x patch_size``.
        gan_type: currently unused (kept for interface compatibility).
    """

    def __init__(self, args, gan_type='GAN'):
        super(Discriminator, self).__init__()

        in_channels = 3
        out_channels = 64
        depth = 7
        # bn = not gan_type == 'WGAN_GP'
        bn = True
        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        m_features = [
            BasicBlock(in_channels, out_channels, 3, bn=bn, act=act)
        ]
        for i in range(depth):
            in_channels = out_channels
            if i % 2 == 1:
                stride = 1
                out_channels *= 2
            else:
                stride = 2
            m_features.append(BasicBlock(
                in_channels, out_channels, 3, stride=stride, bn=bn, act=act
            ))

        self.features = nn.Sequential(*m_features)

        self.patch_size = args.patch_size
        # Even loop indices use stride 2, giving ceil(depth / 2) downsamples.
        # A 3x3 conv with padding 1 and stride 2 maps size n to ceil(n / 2).
        # Flooring by 2**k would undercount whenever patch_size is not a
        # multiple of 2**k, and the Linear layer below would not match.
        n_downsample = (depth + 1) // 2
        feature_patch_size = self.patch_size
        for _ in range(n_downsample):
            feature_patch_size = (feature_patch_size + 1) // 2
        m_classifier = [
            nn.Linear(out_channels * feature_patch_size ** 2, 1024),
            act,
            nn.Linear(1024, 1)
        ]
        self.classifier = nn.Sequential(*m_classifier)

    def forward(self, x):
        """Return logits of shape (batch, 1) for images ``x``.

        ``x`` is indexed (batch, channels, height, width). Inputs larger than
        ``patch_size`` are center-cropped to ``patch_size x patch_size``.

        Raises:
            ValueError: if either spatial dimension is smaller than
                ``patch_size``, because no valid crop exists.
        """
        height, width = x.size(2), x.size(3)
        if height < self.patch_size or width < self.patch_size:
            raise ValueError(
                'Discriminator input {}x{} is smaller than patch_size {}'.format(
                    height, width, self.patch_size))
        if height != self.patch_size or width != self.patch_size:
            p = self.patch_size // 2
            top = height // 2 - p
            left = width // 2 - p
            x = x[:, :, top:top + self.patch_size, left:left + self.patch_size]
        features = self.features(x)
        output = self.classifier(features.view(features.size(0), -1))

        return output


import torch.optim as optim


class Adversarial(nn.Module):
    """Adversarial loss that owns and trains its own discriminator.

    Each ``forward`` call first runs ``gan_k`` discriminator updates on
    detached fakes versus ``real``. The optimizer steps only in training
    mode. The call then returns the generator loss computed on ``fake``.

    Supported ``gan_type`` values:

    - ``'GAN'``: BCE with logits.
    - ``'WGAN'``: critic loss with weight clipping to [-1, 1].
    - Any type containing both ``'WGAN'`` and ``'GP'``: critic loss with a
      gradient penalty.
    """

    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = 1  # discriminator updates per forward call (args.gan_k not used)
        self.discriminator = torch.nn.DataParallel(Discriminator(args, gan_type))
        if gan_type != 'WGAN_GP':
            self.optimizer = optim.Adam(
                self.discriminator.parameters(),
                betas=(0.9, 0.99), eps=1e-8, lr=1e-4
            )
        else:
            self.optimizer = optim.Adam(
                self.discriminator.parameters(),
                betas=(0, 0.9), eps=1e-8, lr=1e-5
            )
        # self.scheduler = utility.make_scheduler(args, self.optimizer)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=3, verbose=True)

    def forward(self, fake, real, fake_input0=None, fake_input1=None, fake_input_mean=None):
        """Update the discriminator, then return the generator's adversarial loss.

        Args:
            fake: generator output. The returned loss back-propagates into it.
            real: ground-truth images.
            fake_input0, fake_input1: optional extra fakes. They are scored
                by the discriminator only when both are given.
            fake_input_mean: optional extra fake, scored by the discriminator.

        Returns:
            The generator loss as a scalar tensor. The mean discriminator
            loss over the ``gan_k`` updates is stored in ``self.loss`` as a
            float.

        Raises:
            ValueError: if ``gan_type`` is neither ``'GAN'`` nor a WGAN variant.
        """
        # ``torch.functional`` has no BCE; the op lives in torch.nn.functional.
        bce = torch.nn.functional.binary_cross_entropy_with_logits
        is_gan = self.gan_type == 'GAN'
        is_wgan = 'WGAN' in self.gan_type
        if not (is_gan or is_wgan):
            raise ValueError('Unsupported gan_type: {}'.format(self.gan_type))

        use_pair = fake_input0 is not None and fake_input1 is not None
        use_mean = fake_input_mean is not None

        fake_detach = fake.detach()
        if use_pair:
            fake0, fake1 = fake_input0.detach(), fake_input1.detach()
        if use_mean:
            fake_m = fake_input_mean.detach()

        self.loss = 0
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            # Call order is kept fixed: each call updates BatchNorm statistics.
            d_fake = self.discriminator(fake_detach)
            if use_pair:
                d_fake0 = self.discriminator(fake0)
                d_fake1 = self.discriminator(fake1)
            if use_mean:
                d_fake_m = self.discriminator(fake_m)
            d_real = self.discriminator(real)

            if is_gan:
                label_fake = torch.zeros_like(d_fake)
                label_real = torch.ones_like(d_real)
                loss_d = bce(d_fake, label_fake) + bce(d_real, label_real)
                if use_pair:
                    loss_d += bce(d_fake0, label_fake) + bce(d_fake1, label_fake)
                if use_mean:
                    loss_d += bce(d_fake_m, label_fake)
            else:
                # NOTE(review): the auxiliary fakes are scored above but only
                # contribute to the loss in the 'GAN' branch.
                loss_d = (d_fake - d_real).mean()
                if 'GP' in self.gan_type:
                    # One interpolation coefficient per sample, broadcast over
                    # the remaining dims (rand_like(fake).view(-1, 1, 1, 1) had
                    # B*C*H*W rows and could not broadcast against ``fake``).
                    epsilon = torch.rand(
                        fake.size(0), 1, 1, 1, device=fake.device, dtype=fake.dtype)
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat = hat.detach().requires_grad_(True)
                    d_hat = self.discriminator(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty

            # Discriminator update
            self.loss += loss_d.item()
            if self.training:
                loss_d.backward()
                self.optimizer.step()

            if self.gan_type == 'WGAN':
                for p in self.discriminator.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        # Generator loss: gradients flow through ``fake`` (not detached).
        d_fake_for_g = self.discriminator(fake)
        if is_gan:
            loss_g = bce(d_fake_for_g, torch.ones_like(d_fake_for_g))
        else:
            loss_g = -d_fake_for_g.mean()

        return loss_g

    def state_dict(self, *args, **kwargs):
        """Return the discriminator's state merged with its optimizer's state.

        NOTE(review): this overrides ``nn.Module.state_dict``. When a parent
        module serializes this one, the returned dict is presumably ignored
        and the optimizer state is not captured. Confirm before relying on
        checkpoint round-trips.
        """
        state_discriminator = self.discriminator.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)
class Loss(nn.modules.loss._Loss):
    """Weighted combination of the configured training losses.

    The configuration is currently hard-coded to a single
    ``1 * CharbonnierLoss``. ``args`` is only forwarded to ``Adversarial``
    when a GAN loss type is selected.

    Each entry of ``self.loss`` is a dict with ``'type'``, ``'weight'`` and
    ``'function'`` keys. GAN losses add a ``'DIS'`` bookkeeping entry, and
    setups with more than one entry add a ``'Total'`` entry.

    Raises:
        ValueError: if the configured loss type is not recognized.
    """

    def __init__(self, args):
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.loss = []
        self.loss_module = nn.ModuleList()
        # loss  =  args['train']['loss']
        # loss  =  args['loss']
        # weight, loss_type = loss.split('*')
        weight = 1
        loss_type = 'CharbonnierLoss'
        if loss_type == 'MSE':
            loss_function = nn.MSELoss()
        elif loss_type == 'Huber':
            loss_function = HuberLoss(delta=.5)
        elif loss_type == 'L1':
            loss_function = nn.L1Loss()
        elif 'VGG' in loss_type:
            loss_function = VGG(loss_type[3:])
        elif loss_type == 'SSIM':
            # NOTE(review): requires the ``pytorch_msssim`` import at the top
            # of the file, which is currently commented out.
            loss_function = pytorch_msssim.SSIM(val_range=1.)
        elif 'GAN' in loss_type:
            loss_function = Adversarial(args, loss_type)
        elif loss_type == 'CharbonnierLoss':
            loss_function = CharbonnierLoss()
        else:
            # Previously fell through and failed later with a NameError.
            raise ValueError('Unknown loss type: {}'.format(loss_type))
        self.loss.append({
            'type': loss_type,
            'weight': float(weight),
            'function': loss_function,
        })
        if 'GAN' in loss_type:
            # Function-less entry; forward() fills it with the discriminator loss.
            self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for entry in self.loss:
            if entry['function'] is not None:
                print('{:.3f} * {}'.format(entry['weight'], entry['type']))
                self.loss_module.append(entry['function'])

        device = torch.device('cuda')
        self.loss_module.to(device)
        self.loss_module = nn.DataParallel(self.loss_module)

    def forward(self, sr, hr, fake_imgs=None):
        """Return ``(total_loss, losses)`` for prediction ``sr`` versus target ``hr``.

        Args:
            sr: prediction passed to every loss function.
            hr: target passed to every loss function.
            fake_imgs: optional 3-sequence of extra fakes, forwarded to the
                ``'GAN'`` loss as ``fake_input0``, ``fake_input1`` and
                ``fake_input_mean``.

        Returns:
            The weighted sum of all losses, and a dict mapping each loss type
            to its weighted value. The ``'DIS'`` key holds the preceding
            adversarial entry's discriminator loss.
        """
        loss = 0
        losses = {}
        for i, entry in enumerate(self.loss):
            if entry['function'] is not None:
                if entry['type'] == 'GAN':
                    if fake_imgs is None:
                        fake_imgs = [None, None, None]
                    _loss = entry['function'](sr, hr, fake_imgs[0], fake_imgs[1], fake_imgs[2])
                else:
                    _loss = entry['function'](sr, hr)
                effective_loss = entry['weight'] * _loss
                losses[entry['type']] = effective_loss
                loss += effective_loss
            elif entry['type'] == 'DIS':
                # 'DIS' is appended right after its Adversarial entry.
                losses[entry['type']] = self.loss[i - 1]['function'].loss

        return loss, losses
